{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Importance of decision tree hyperparameters on generalization\n",
    "\n",
    "In this notebook, we illustrate the effect of some key hyperparameters of\n",
    "decision trees on generalization, using the classification and regression\n",
    "problems we saw previously.\n",
    "\n",
    "First, we load the classification and regression datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "data_clf_columns = [\"Culmen Length (mm)\", \"Culmen Depth (mm)\"]\n",
    "target_clf_column = \"Species\"\n",
    "data_clf = pd.read_csv(\"../datasets/penguins_classification.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_reg_columns = [\"Flipper Length (mm)\"]\n",
    "target_reg_column = \"Body Mass (g)\"\n",
    "data_reg = pd.read_csv(\"../datasets/penguins_regression.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"admonition note alert alert-info\">\n",
    "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Note</p>\n",
    "<p class=\"last\">If you want a deeper overview regarding this dataset, you can refer to the\n",
    "Appendix - Datasets description section at the end of this MOOC.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create helper functions\n",
    "\n",
    "We create some helper functions to plot the data samples as well as the\n",
    "decision boundary for classification and the regression line for regression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.inspection import DecisionBoundaryDisplay\n",
    "\n",
    "\n",
    "def fit_and_plot_classification(model, data, feature_names, target_names):\n",
    "    model.fit(data[feature_names], data[target_names])\n",
    "    # Use one color per class (2 or 3 classes are supported)\n",
    "    if data[target_names].nunique() == 2:\n",
    "        palette = [\"tab:red\", \"tab:blue\"]\n",
    "    else:\n",
    "        palette = [\"tab:red\", \"tab:blue\", \"black\"]\n",
    "    # Draw the decision boundary, then overlay the training samples\n",
    "    DecisionBoundaryDisplay.from_estimator(\n",
    "        model,\n",
    "        data[feature_names],\n",
    "        response_method=\"predict\",\n",
    "        cmap=\"RdBu\",\n",
    "        alpha=0.5,\n",
    "    )\n",
    "    sns.scatterplot(\n",
    "        data=data,\n",
    "        x=feature_names[0],\n",
    "        y=feature_names[1],\n",
    "        hue=target_names,\n",
    "        palette=palette,\n",
    "    )\n",
    "    plt.legend(bbox_to_anchor=(1.05, 1), loc=\"upper left\")\n",
    "\n",
    "\n",
    "def fit_and_plot_regression(model, data, feature_names, target_names):\n",
    "    model.fit(data[feature_names], data[target_names])\n",
    "    # Select the feature by name instead of relying on its column position\n",
    "    feature = data[feature_names[0]]\n",
    "    # Evaluate the model on a regular grid spanning the feature range\n",
    "    data_test = pd.DataFrame(\n",
    "        np.arange(feature.min(), feature.max()),\n",
    "        columns=feature_names,\n",
    "    )\n",
    "    target_predicted = model.predict(data_test)\n",
    "\n",
    "    sns.scatterplot(\n",
    "        x=feature, y=data[target_names], color=\"black\", alpha=0.5\n",
    "    )\n",
    "    plt.plot(data_test[feature_names[0]], target_predicted, linewidth=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Effect of the `max_depth` parameter\n",
    "\n",
    "The hyperparameter `max_depth` controls the overall complexity of a decision\n",
    "tree. Tuning it lets us trade off between an underfitted and an overfitted\n",
    "decision tree. Let's build a shallow tree and then a deeper tree, for both\n",
    "classification and regression, to understand the impact of the parameter.\n",
    "\n",
    "We first set the `max_depth` parameter to a very low value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor\n",
    "\n",
    "max_depth = 2\n",
    "tree_clf = DecisionTreeClassifier(max_depth=max_depth)\n",
    "tree_reg = DecisionTreeRegressor(max_depth=max_depth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_and_plot_classification(\n",
    "    tree_clf, data_clf, data_clf_columns, target_clf_column\n",
    ")\n",
    "_ = plt.title(f\"Shallow classification tree with max-depth of {max_depth}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_and_plot_regression(\n",
    "    tree_reg, data_reg, data_reg_columns, target_reg_column\n",
    ")\n",
    "_ = plt.title(f\"Shallow regression tree with max-depth of {max_depth}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's increase the `max_depth` parameter value to check the difference by\n",
    "observing the decision function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_depth = 30\n",
    "tree_clf = DecisionTreeClassifier(max_depth=max_depth)\n",
    "tree_reg = DecisionTreeRegressor(max_depth=max_depth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_and_plot_classification(\n",
    "    tree_clf, data_clf, data_clf_columns, target_clf_column\n",
    ")\n",
    "_ = plt.title(f\"Deep classification tree with max-depth of {max_depth}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_and_plot_regression(\n",
    "    tree_reg, data_reg, data_reg_columns, target_reg_column\n",
    ")\n",
    "_ = plt.title(f\"Deep regression tree with max-depth of {max_depth}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In both the classification and regression settings, we observe that\n",
    "increasing the depth makes the tree model more expressive. However, a tree\n",
    "that is too deep may overfit the training data, creating partitions that are\n",
    "only correct for \"outliers\" (noisy samples). `max_depth` is one of the\n",
    "hyperparameters that one should optimize via cross-validation and grid-search."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "param_grid = {\"max_depth\": np.arange(2, 10, 1)}\n",
    "tree_clf = GridSearchCV(DecisionTreeClassifier(), param_grid=param_grid)\n",
    "tree_reg = GridSearchCV(DecisionTreeRegressor(), param_grid=param_grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_and_plot_classification(\n",
    "    tree_clf, data_clf, data_clf_columns, target_clf_column\n",
    ")\n",
    "_ = plt.title(\n",
    "    f\"Optimal depth found via CV: {tree_clf.best_params_['max_depth']}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_and_plot_regression(\n",
    "    tree_reg, data_reg, data_reg_columns, target_reg_column\n",
    ")\n",
    "_ = plt.title(\n",
    "    f\"Optimal depth found via CV: {tree_reg.best_params_['max_depth']}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With this example, we see that no single value is optimal for every dataset.\n",
    "Thus, this parameter must be tuned for each application.\n",
    "\n",
    "## Other hyperparameters in decision trees\n",
    "\n",
    "The `max_depth` hyperparameter controls the overall complexity of the tree.\n",
    "This parameter is adequate under the assumption that a tree is built\n",
    "symmetrically. However, there is no reason why a tree should be symmetrical.\n",
    "Indeed, optimal generalization performance could be reached by growing some of\n",
    "the branches deeper than others.\n",
    "\n",
    "We build a dataset to illustrate this asymmetry. It is composed of 2 subsets:\n",
    "one subset where the tree should find a clear separation and another subset\n",
    "where samples from both classes are mixed. This implies that a decision tree\n",
    "needs more splits to properly classify samples from the second subset than\n",
    "from the first subset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_blobs\n",
    "\n",
    "data_clf_columns = [\"Feature #0\", \"Feature #1\"]\n",
    "target_clf_column = \"Class\"\n",
    "\n",
    "# Blobs that are interlaced\n",
    "X_1, y_1 = make_blobs(\n",
    "    n_samples=300, centers=[[0, 0], [-1, -1]], random_state=0\n",
    ")\n",
    "# Blobs that can be easily separated\n",
    "X_2, y_2 = make_blobs(n_samples=300, centers=[[3, 6], [7, 0]], random_state=0)\n",
    "\n",
    "X = np.concatenate([X_1, X_2], axis=0)\n",
    "y = np.concatenate([y_1, y_2])\n",
    "data_clf = np.concatenate([X, y[:, np.newaxis]], axis=1)\n",
    "data_clf = pd.DataFrame(\n",
    "    data_clf, columns=data_clf_columns + [target_clf_column]\n",
    ")\n",
    "data_clf[target_clf_column] = data_clf[target_clf_column].astype(np.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.scatterplot(\n",
    "    data=data_clf,\n",
    "    x=data_clf_columns[0],\n",
    "    y=data_clf_columns[1],\n",
    "    hue=target_clf_column,\n",
    "    palette=[\"tab:red\", \"tab:blue\"],\n",
    ")\n",
    "_ = plt.title(\"Synthetic dataset\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first train a shallow decision tree with `max_depth=2`. We would expect\n",
    "this depth to be enough to separate the blobs that are easy to separate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_depth = 2\n",
    "tree_clf = DecisionTreeClassifier(max_depth=max_depth)\n",
    "fit_and_plot_classification(\n",
    "    tree_clf, data_clf, data_clf_columns, target_clf_column\n",
    ")\n",
    "_ = plt.title(f\"Decision tree with max-depth of {max_depth}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, we see that the blue blob in the lower right and the red blob on\n",
    "the top are easily separated. However, more splits are required to better\n",
    "split the blob where both blue and red data points are mixed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.tree import plot_tree\n",
    "\n",
    "_, ax = plt.subplots(figsize=(10, 10))\n",
    "_ = plot_tree(tree_clf, ax=ax, feature_names=data_clf_columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the right branch achieves perfect classification. Now, we increase\n",
    "the depth to check how the tree grows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_depth = 6\n",
    "tree_clf = DecisionTreeClassifier(max_depth=max_depth)\n",
    "fit_and_plot_classification(\n",
    "    tree_clf, data_clf, data_clf_columns, target_clf_column\n",
    ")\n",
    "_ = plt.title(f\"Decision tree with max-depth of {max_depth}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, ax = plt.subplots(figsize=(11, 7))\n",
    "_ = plot_tree(tree_clf, ax=ax, feature_names=data_clf_columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, the left branch of the tree continues to grow while no further\n",
    "splits are made on the right branch. Fixing the `max_depth` parameter cuts\n",
    "the tree horizontally at a specific level, whether or not some branches would\n",
    "benefit from growing further.\n",
    "\n",
    "The hyperparameters `min_samples_leaf`, `min_samples_split`, `max_leaf_nodes`,\n",
    "and `min_impurity_decrease` allow growing asymmetric trees by applying a\n",
    "constraint at the level of leaves or nodes. We check the effect of\n",
    "`min_samples_leaf`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "min_samples_leaf = 60\n",
    "tree_clf = DecisionTreeClassifier(min_samples_leaf=min_samples_leaf)\n",
    "fit_and_plot_classification(\n",
    "    tree_clf, data_clf, data_clf_columns, target_clf_column\n",
    ")\n",
    "_ = plt.title(\n",
    "    f\"Decision tree with leaf having at least {min_samples_leaf} samples\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, ax = plt.subplots(figsize=(10, 7))\n",
    "_ = plot_tree(tree_clf, ax=ax, feature_names=data_clf_columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This hyperparameter ensures that every leaf contains at least a minimum number\n",
    "of samples; a split that would create a smaller leaf is not considered.\n",
    "Therefore, such hyperparameters are an alternative to fixing the `max_depth`\n",
    "hyperparameter."
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}